{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 加法进位实验\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<img src=\"https://github.com/JerrikEph/jerrikeph.github.io/raw/master/Learn2Carry.png\" width=650>"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/remote-home/jjgong01/a3/envs/tf2.0/lib/python3.6/site-packages/h5py/__init__.py:36: FutureWarning: Conversion of the second argument of issubdtype from `float` to `np.floating` is deprecated. In future, it will be treated as `np.float64 == np.dtype(float).type`.\n",
      "  from ._conv import register_converters as _register_converters\n"
     ]
    }
   ],
   "source": [
    "import numpy as np\n",
    "import tensorflow as tf\n",
    "import collections\n",
    "from tensorflow import keras\n",
    "from tensorflow.keras import layers\n",
    "from tensorflow.keras import layers, optimizers, datasets\n",
    "import os,sys,tqdm\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 数据生成\n",
    "我们随机在 `start->end`之间采样除整数对`(num1, num2)`，计算结果`num1+num2`作为监督信号。\n",
    "\n",
    "* 首先将数字转换成数字位列表 `convertNum2Digits`\n",
    "* 将数字位列表反向\n",
    "* 将数字位列表填充到同样的长度 `pad2len`\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "def gen_data_batch(batch_size, start, end):\n",
    "    '''在(start, end)区间采样生成一个batch的整型的数据\n",
    "    Args :\n",
    "        batch_size: batch_size\n",
    "        start: 开始数值\n",
    "        end: 结束数值\n",
    "    '''\n",
    "    numbers_1 = np.random.randint(start, end, batch_size)\n",
    "    numbers_2 = np.random.randint(start, end, batch_size)\n",
    "    results = numbers_1 + numbers_2\n",
    "    return numbers_1, numbers_2, results\n",
    "\n",
    "def convertNum2Digits(Num):\n",
    "    '''将一个整数转换成一个数字位的列表,例如 133412 ==> [1, 3, 3, 4, 1, 2]\n",
    "    '''\n",
    "    strNum = str(Num)\n",
    "    chNums = list(strNum)\n",
    "    digitNums = [int(o) for o in strNum]\n",
    "    return digitNums\n",
    "\n",
    "def convertDigits2Num(Digits):\n",
    "    '''将数字位列表反向， 例如 [1, 3, 3, 4, 1, 2] ==> [2, 1, 4, 3, 3, 1]\n",
    "    '''\n",
    "    digitStrs = [str(o) for o in Digits]\n",
    "    numStr = ''.join(digitStrs)\n",
    "    Num = int(numStr)\n",
    "    return Num\n",
    "\n",
    "def pad2len(lst, length, pad=0):\n",
    "    '''将一个列表用`pad`填充到`length`的长度 例如 pad2len([1, 3, 2, 3], 6, pad=0) ==> [1, 3, 2, 3, 0, 0]\n",
    "    '''\n",
    "    lst+=[pad]*(length - len(lst))\n",
    "    return lst\n",
    "\n",
    "def results_converter(res_lst):\n",
    "    '''将预测好的数字位列表批量转换成为原始整数\n",
    "    Args:\n",
    "        res_lst: shape(b_sz, len(digits))\n",
    "    '''\n",
    "    res = [reversed(digits) for digits in res_lst]\n",
    "    return [convertDigits2Num(digits) for digits in res]\n",
    "\n",
    "def prepare_batch(Nums1, Nums2, results, maxlen):\n",
    "    '''准备一个batch的数据，将数值转换成反转的数位列表并且填充到固定长度\n",
    "    Args:\n",
    "        Nums1: shape(batch_size,)\n",
    "        Nums2: shape(batch_size,)\n",
    "        results: shape(batch_size,)\n",
    "        maxlen:  type(int)\n",
    "    Returns:\n",
    "        Nums1: shape(batch_size, maxlen)\n",
    "        Nums2: shape(batch_size, maxlen)\n",
    "        results: shape(batch_size, maxlen)\n",
    "    '''\n",
    "    Nums1 = [convertNum2Digits(o) for o in Nums1]\n",
    "    Nums2 = [convertNum2Digits(o) for o in Nums2]\n",
    "    results = [convertNum2Digits(o) for o in results]\n",
    "    \n",
    "    Nums1 = [list(reversed(o)) for o in Nums1]\n",
    "    Nums2 = [list(reversed(o)) for o in Nums2]\n",
    "    results = [list(reversed(o)) for o in results]\n",
    "    \n",
    "    Nums1 = [pad2len(o, maxlen) for o in Nums1]\n",
    "    Nums2 = [pad2len(o, maxlen) for o in Nums2]\n",
    "    results = [pad2len(o, maxlen) for o in results]\n",
    "    \n",
    "    return Nums1, Nums2, results"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 建模过程， 按照图示完成建模"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "class myRNNModel(keras.Model):\n",
    "    def __init__(self):\n",
    "        super(myRNNModel, self).__init__()\n",
    "        self.embed_layer = tf.keras.layers.Embedding(10, 32, \n",
    "                                                    batch_input_shape=[None, None])\n",
    "        \n",
    "        self.rnncell = tf.keras.layers.SimpleRNNCell(64)\n",
    "        self.rnn_layer = tf.keras.layers.RNN(self.rnncell, return_sequences=True)\n",
    "        self.dense = tf.keras.layers.Dense(10)\n",
    "        \n",
    "    @tf.function\n",
    "    def call(self, num1, num2):\n",
    "        '''\n",
    "        此处完成上述图中模型\n",
    "        '''\n",
    "        return logits"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "@tf.function\n",
    "def compute_loss(logits, labels):\n",
    "    losses = tf.nn.sparse_softmax_cross_entropy_with_logits(\n",
    "            logits=logits, labels=labels)\n",
    "    return tf.reduce_mean(losses)\n",
    "\n",
    "@tf.function\n",
    "def train_one_step(model, optimizer, x, y, label):\n",
    "    with tf.GradientTape() as tape:\n",
    "        logits = model(x, y)\n",
    "        loss = compute_loss(logits, label)\n",
    "\n",
    "    # compute gradient\n",
    "    grads = tape.gradient(loss, model.trainable_variables)\n",
    "    optimizer.apply_gradients(zip(grads, model.trainable_variables))\n",
    "    return loss\n",
    "\n",
    "def train(steps, model, optimizer):\n",
    "    loss = 0.0\n",
    "    accuracy = 0.0\n",
    "    for step in range(steps):\n",
    "        datas = gen_data_batch(batch_size=200, start=0, end=555555555)\n",
    "        Nums1, Nums2, results = prepare_batch(*datas, maxlen=11)\n",
    "        loss = train_one_step(model, optimizer, tf.constant(Nums1, dtype=tf.int32), \n",
    "                              tf.constant(Nums2, dtype=tf.int32),\n",
    "                              tf.constant(results, dtype=tf.int32))\n",
    "        if step%50 == 0:\n",
    "            print('step', step, ': loss', loss.numpy())\n",
    "\n",
    "    return loss\n",
    "\n",
    "def evaluate(model):\n",
    "    datas = gen_data_batch(batch_size=2000, start=555555555, end=999999999)\n",
    "    Nums1, Nums2, results = prepare_batch(*datas, maxlen=11)\n",
    "    logits = model(tf.constant(Nums1, dtype=tf.int32), tf.constant(Nums2, dtype=tf.int32))\n",
    "    logits = logits.numpy()\n",
    "    pred = np.argmax(logits, axis=-1)\n",
    "    res = results_converter(pred)\n",
    "    for o in list(zip(datas[2], res))[:20]:\n",
    "        print(o[0], o[1], o[0]==o[1])\n",
    "\n",
    "    print('accuracy is: %g' % np.mean([o[0]==o[1] for o in zip(datas[2], res)]))\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "optimizer = optimizers.Adam(0.001)\n",
    "model = myRNNModel()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "step 0 : loss 2.338743\n",
      "step 50 : loss 1.9090827\n",
      "step 100 : loss 1.8924072\n",
      "step 150 : loss 1.8782815\n",
      "step 200 : loss 1.8933363\n",
      "step 250 : loss 1.8919109\n",
      "step 300 : loss 1.8841321\n",
      "step 350 : loss 1.8789333\n",
      "step 400 : loss 1.8834051\n",
      "step 450 : loss 1.8841839\n",
      "step 500 : loss 1.8757857\n",
      "step 550 : loss 1.883435\n",
      "step 600 : loss 1.8695543\n",
      "step 650 : loss 1.8730036\n",
      "step 700 : loss 1.8712887\n",
      "step 750 : loss 1.8748087\n",
      "step 800 : loss 1.867813\n",
      "step 850 : loss 1.8665794\n",
      "step 900 : loss 1.8346387\n",
      "step 950 : loss 1.771845\n",
      "step 1000 : loss 1.66644\n",
      "step 1050 : loss 1.5232488\n",
      "step 1100 : loss 1.216567\n",
      "step 1150 : loss 0.98256004\n",
      "step 1200 : loss 0.8200839\n",
      "step 1250 : loss 0.69820565\n",
      "step 1300 : loss 0.5886905\n",
      "step 1350 : loss 0.50323135\n",
      "step 1400 : loss 0.42252934\n",
      "step 1450 : loss 0.36340263\n",
      "step 1500 : loss 0.2949791\n",
      "step 1550 : loss 0.24327374\n",
      "step 1600 : loss 0.20031683\n",
      "step 1650 : loss 0.16559783\n",
      "step 1700 : loss 0.13266504\n",
      "step 1750 : loss 0.11583002\n",
      "step 1800 : loss 0.09522024\n",
      "step 1850 : loss 0.08221528\n",
      "step 1900 : loss 0.07254601\n",
      "step 1950 : loss 0.06226328\n",
      "step 2000 : loss 0.054419417\n",
      "step 2050 : loss 0.049692407\n",
      "step 2100 : loss 0.043883115\n",
      "step 2150 : loss 0.038623303\n",
      "step 2200 : loss 0.03507858\n",
      "step 2250 : loss 0.031759493\n",
      "step 2300 : loss 0.028898347\n",
      "step 2350 : loss 0.026097616\n",
      "step 2400 : loss 0.024030704\n",
      "step 2450 : loss 0.022617567\n",
      "step 2500 : loss 0.020465054\n",
      "step 2550 : loss 0.018972643\n",
      "step 2600 : loss 0.017769149\n",
      "step 2650 : loss 0.016359653\n",
      "step 2700 : loss 0.015172582\n",
      "step 2750 : loss 0.014478247\n",
      "step 2800 : loss 0.013233788\n",
      "step 2850 : loss 0.012570047\n",
      "step 2900 : loss 0.011696805\n",
      "step 2950 : loss 0.01100141\n",
      "1314970279 1314970279 True\n",
      "1646630267 1646630267 True\n",
      "1765093362 1765093362 True\n",
      "1671475338 1671475338 True\n",
      "1467649664 1467649664 True\n",
      "1247588085 1247588085 True\n",
      "1396644726 1396644726 True\n",
      "1363359355 1363359355 True\n",
      "1617949833 1617949833 True\n",
      "1629931811 1629931811 True\n",
      "1452165053 1452165053 True\n",
      "1589216399 1589216399 True\n",
      "1696585095 1696585095 True\n",
      "1492344012 1492344012 True\n",
      "1398874117 1398874117 True\n",
      "1659389908 1659389908 True\n",
      "1808769780 1808769780 True\n",
      "1767070925 1767070925 True\n",
      "1252130370 1252130370 True\n",
      "1345605622 1345605622 True\n",
      "accuracy is: 1\n"
     ]
    }
   ],
   "source": [
    "train(3000, model, optimizer)\n",
    "evaluate(model)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.0"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
